{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Column names:\n",
      "['State', 'Account Length', 'Area Code', 'Phone', \"Int'l Plan\", 'VMail Plan', 'VMail Message', 'Day Mins', 'Day Calls', 'Day Charge', 'Eve Mins', 'Eve Calls', 'Eve Charge', 'Night Mins', 'Night Calls', 'Night Charge', 'Intl Mins', 'Intl Calls', 'Intl Charge', 'CustServ Calls', 'Churn?']\n",
      "\n",
      "Sample data:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>State</th>\n",
       "      <th>Account Length</th>\n",
       "      <th>Area Code</th>\n",
       "      <th>Phone</th>\n",
       "      <th>Int'l Plan</th>\n",
       "      <th>VMail Plan</th>\n",
       "      <th>Night Charge</th>\n",
       "      <th>Intl Mins</th>\n",
       "      <th>Intl Calls</th>\n",
       "      <th>Intl Charge</th>\n",
       "      <th>CustServ Calls</th>\n",
       "      <th>Churn?</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>KS</td>\n",
       "      <td>128</td>\n",
       "      <td>415</td>\n",
       "      <td>382-4657</td>\n",
       "      <td>no</td>\n",
       "      <td>yes</td>\n",
       "      <td>11.01</td>\n",
       "      <td>10.0</td>\n",
       "      <td>3</td>\n",
       "      <td>2.70</td>\n",
       "      <td>1</td>\n",
       "      <td>False.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>OH</td>\n",
       "      <td>107</td>\n",
       "      <td>415</td>\n",
       "      <td>371-7191</td>\n",
       "      <td>no</td>\n",
       "      <td>yes</td>\n",
       "      <td>11.45</td>\n",
       "      <td>13.7</td>\n",
       "      <td>3</td>\n",
       "      <td>3.70</td>\n",
       "      <td>1</td>\n",
       "      <td>False.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>NJ</td>\n",
       "      <td>137</td>\n",
       "      <td>415</td>\n",
       "      <td>358-1921</td>\n",
       "      <td>no</td>\n",
       "      <td>no</td>\n",
       "      <td>7.32</td>\n",
       "      <td>12.2</td>\n",
       "      <td>5</td>\n",
       "      <td>3.29</td>\n",
       "      <td>0</td>\n",
       "      <td>False.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>OH</td>\n",
       "      <td>84</td>\n",
       "      <td>408</td>\n",
       "      <td>375-9999</td>\n",
       "      <td>yes</td>\n",
       "      <td>no</td>\n",
       "      <td>8.86</td>\n",
       "      <td>6.6</td>\n",
       "      <td>7</td>\n",
       "      <td>1.78</td>\n",
       "      <td>2</td>\n",
       "      <td>False.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>OK</td>\n",
       "      <td>75</td>\n",
       "      <td>415</td>\n",
       "      <td>330-6626</td>\n",
       "      <td>yes</td>\n",
       "      <td>no</td>\n",
       "      <td>8.41</td>\n",
       "      <td>10.1</td>\n",
       "      <td>3</td>\n",
       "      <td>2.73</td>\n",
       "      <td>3</td>\n",
       "      <td>False.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>AL</td>\n",
       "      <td>118</td>\n",
       "      <td>510</td>\n",
       "      <td>391-8027</td>\n",
       "      <td>yes</td>\n",
       "      <td>no</td>\n",
       "      <td>9.18</td>\n",
       "      <td>6.3</td>\n",
       "      <td>6</td>\n",
       "      <td>1.70</td>\n",
       "      <td>0</td>\n",
       "      <td>False.</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  State  Account Length  Area Code     Phone Int'l Plan VMail Plan  \\\n",
       "0    KS             128        415  382-4657         no        yes   \n",
       "1    OH             107        415  371-7191         no        yes   \n",
       "2    NJ             137        415  358-1921         no         no   \n",
       "3    OH              84        408  375-9999        yes         no   \n",
       "4    OK              75        415  330-6626        yes         no   \n",
       "5    AL             118        510  391-8027        yes         no   \n",
       "\n",
       "   Night Charge  Intl Mins  Intl Calls  Intl Charge  CustServ Calls  Churn?  \n",
       "0         11.01       10.0           3         2.70               1  False.  \n",
       "1         11.45       13.7           3         3.70               1  False.  \n",
       "2          7.32       12.2           5         3.29               0  False.  \n",
       "3          8.86        6.6           7         1.78               2  False.  \n",
       "4          8.41       10.1           3         2.73               3  False.  \n",
       "5          9.18        6.3           6         1.70               0  False.  "
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from __future__ import division\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "churn_df = pd.read_csv('churn.csv')\n",
    "col_names = churn_df.columns.tolist()\n",
    "\n",
    "print \"Column names:\"\n",
    "print col_names\n",
    "\n",
    "to_show = col_names[:6] + col_names[-6:]\n",
    "\n",
    "print \"\\nSample data:\"\n",
    "churn_df[to_show].head(6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Feature space holds 3333 observations and 17 features\n",
      "Unique target labels: [0 1]\n",
      "[ 0.67648946 -0.32758048  1.6170861   1.23488274  1.56676695  0.47664315\n",
      "  1.56703625 -0.07060962 -0.05594035 -0.07042665  0.86674322 -0.46549436\n",
      "  0.86602851 -0.08500823 -0.60119509 -0.0856905  -0.42793202]\n",
      "2850\n"
     ]
    }
   ],
   "source": [
    "churn_result = churn_df['Churn?']\n",
    "y = np.where(churn_result == 'True.',1,0)\n",
    "\n",
    "# We don't need these columns\n",
    "to_drop = ['State','Area Code','Phone','Churn?']\n",
    "churn_feat_space = churn_df.drop(to_drop,axis=1)\n",
    "\n",
    "# 'yes'/'no' has to be converted to boolean values\n",
    "# NumPy converts these from boolean to 1. and 0. later\n",
    "yes_no_cols = [\"Int'l Plan\",\"VMail Plan\"]\n",
    "churn_feat_space[yes_no_cols] = churn_feat_space[yes_no_cols] == 'yes'\n",
    "\n",
    "# Pull out features for future use\n",
    "features = churn_feat_space.columns\n",
    "\n",
    "X = churn_feat_space.as_matrix().astype(np.float)\n",
    "\n",
    "# This is important\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "scaler = StandardScaler()\n",
    "X = scaler.fit_transform(X)\n",
    "\n",
    "print \"Feature space holds %d observations and %d features\" % X.shape\n",
    "print \"Unique target labels:\", np.unique(y)\n",
    "print X[0]\n",
    "print len(y[y == 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.cross_validation import KFold\n",
    "\n",
    "def run_cv(X,y,clf_class,**kwargs):\n",
    "    # Construct a kfolds object\n",
    "    kf = KFold(len(y),n_folds=5,shuffle=True)\n",
    "    y_pred = y.copy()\n",
    "\n",
    "    # Iterate through folds\n",
    "    for train_index, test_index in kf:\n",
    "        X_train, X_test = X[train_index], X[test_index]\n",
    "        y_train = y[train_index]\n",
    "        # Initialize a classifier with key word arguments\n",
    "        clf = clf_class(**kwargs)\n",
    "        clf.fit(X_train,y_train)\n",
    "        y_pred[test_index] = clf.predict(X_test)\n",
    "    return y_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Support vector machines:\n",
      "0.916\n",
      "Random forest:\n",
      "0.944\n",
      "K-nearest-neighbors:\n",
      "0.893\n"
     ]
    }
   ],
   "source": [
    "from sklearn.svm import SVC\n",
    "from sklearn.ensemble import RandomForestClassifier as RF\n",
    "from sklearn.neighbors import KNeighborsClassifier as KNN\n",
    "\n",
    "def accuracy(y_true,y_pred):\n",
    "    # NumPy interprets True and False as 1. and 0.\n",
    "    return np.mean(y_true == y_pred)\n",
    "\n",
    "print \"Support vector machines:\"\n",
    "print \"%.3f\" % accuracy(y, run_cv(X,y,SVC))\n",
    "print \"Random forest:\"\n",
    "print \"%.3f\" % accuracy(y, run_cv(X,y,RF))\n",
    "print \"K-nearest-neighbors:\"\n",
    "print \"%.3f\" % accuracy(y, run_cv(X,y,KNN))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def run_prob_cv(X, y, clf_class, **kwargs):\n",
    "    kf = KFold(len(y), n_folds=5, shuffle=True)\n",
    "    y_prob = np.zeros((len(y),2))\n",
    "    for train_index, test_index in kf:\n",
    "        X_train, X_test = X[train_index], X[test_index]\n",
    "        y_train = y[train_index]\n",
    "        clf = clf_class(**kwargs)\n",
    "        clf.fit(X_train,y_train)\n",
    "        # Predict probabilities, not classes\n",
    "        y_prob[test_index] = clf.predict_proba(X_test)\n",
    "    return y_prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pred_prob</th>\n",
       "      <th>count</th>\n",
       "      <th>true_prob</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1748</td>\n",
       "      <td>0.026316</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.1</td>\n",
       "      <td>733</td>\n",
       "      <td>0.028649</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.2</td>\n",
       "      <td>256</td>\n",
       "      <td>0.062500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.3</td>\n",
       "      <td>120</td>\n",
       "      <td>0.150000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.9</td>\n",
       "      <td>78</td>\n",
       "      <td>0.974359</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.4</td>\n",
       "      <td>76</td>\n",
       "      <td>0.421053</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0.7</td>\n",
       "      <td>76</td>\n",
       "      <td>0.947368</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0.8</td>\n",
       "      <td>73</td>\n",
       "      <td>0.931507</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0.6</td>\n",
       "      <td>67</td>\n",
       "      <td>0.701493</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1.0</td>\n",
       "      <td>57</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>0.5</td>\n",
       "      <td>49</td>\n",
       "      <td>0.612245</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    pred_prob  count  true_prob\n",
       "0         0.0   1748   0.026316\n",
       "1         0.1    733   0.028649\n",
       "2         0.2    256   0.062500\n",
       "3         0.3    120   0.150000\n",
       "4         0.9     78   0.974359\n",
       "5         0.4     76   0.421053\n",
       "6         0.7     76   0.947368\n",
       "7         0.8     73   0.931507\n",
       "8         0.6     67   0.701493\n",
       "9         1.0     57   1.000000\n",
       "10        0.5     49   0.612245"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Use 10 estimators so predictions are all multiples of 0.1\n",
    "pred_prob = run_prob_cv(X, y, RF, n_estimators=10)\n",
    "#print pred_prob[0]\n",
    "pred_churn = pred_prob[:,1]\n",
    "is_churn = y == 1\n",
    "\n",
    "# Number of times a predicted probability is assigned to an observation\n",
    "counts = pd.value_counts(pred_churn)\n",
    "#print counts\n",
    "\n",
    "# calculate true probabilities\n",
    "true_prob = {}\n",
    "for prob in counts.index:\n",
    "    true_prob[prob] = np.mean(is_churn[pred_churn == prob])\n",
    "    true_prob = pd.Series(true_prob)\n",
    "\n",
    "# pandas-fu\n",
    "counts = pd.concat([counts,true_prob], axis=1).reset_index()\n",
    "counts.columns = ['pred_prob', 'count', 'true_prob']\n",
    "counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
